{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Structure of a sentence-transformers model\n",
    "The loading code below is taken from the `sentence-transformers` package."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "import importlib\n",
    "import os\n",
    "import json\n",
    "from collections import OrderedDict\n",
    "\n",
    "\n",
    "def import_from_string(dotted_path):\n",
    "    \"\"\"\n",
    "    Import a dotted module path and return the attribute/class designated by the\n",
    "    last name in the path.\n",
    "\n",
    "    The full path is tried as a module first; if that import fails, the path\n",
    "    without its last component is imported instead. The last component is then\n",
    "    looked up as an attribute of whichever module was imported.\n",
    "\n",
    "    :param dotted_path: string such as 'package.module.ClassName'\n",
    "    :return: the attribute named by the last component of the path\n",
    "    :raises ImportError: if the path contains no dot, the fallback module cannot\n",
    "        be imported, or the module lacks the requested attribute\n",
    "    \"\"\"\n",
    "    try:\n",
    "        module_path, class_name = dotted_path.rsplit('.', 1)\n",
    "    except ValueError as err:\n",
    "        msg = \"%s doesn't look like a module path\" % dotted_path\n",
    "        raise ImportError(msg) from err\n",
    "\n",
    "    try:\n",
    "        module = importlib.import_module(dotted_path)\n",
    "    except Exception:\n",
    "        # Fall back to the parent module on any ordinary failure, but let\n",
    "        # KeyboardInterrupt / SystemExit propagate instead of swallowing them.\n",
    "        module = importlib.import_module(module_path)\n",
    "\n",
    "    try:\n",
    "        return getattr(module, class_name)\n",
    "    except AttributeError as err:\n",
    "        msg = 'Module \"%s\" does not define a \"%s\" attribute/class' % (module_path, class_name)\n",
    "        raise ImportError(msg) from err"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [],
   "source": [
    "# Model directory; modules.json and each module's path are resolved relative to it\n",
    "model_path = \"../models/paraphrase-multilingual-MiniLM-L12-v2\""
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [],
   "source": [
    "# Read the list of modules declared for this model.\n",
    "modules_json_path = os.path.join(model_path, 'modules.json')\n",
    "with open(modules_json_path) as config_file:\n",
    "    modules_config = json.load(config_file)\n",
    "\n",
    "# Instantiate every declared module, keyed by its configured name.\n",
    "modules = OrderedDict()\n",
    "for entry in modules_config:\n",
    "    loader = import_from_string(entry['type'])\n",
    "    module_dir = os.path.join(model_path, entry['path'])\n",
    "    modules[entry['name']] = loader.load(module_dir)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [
    {
     "data": {
      "text/plain": "OrderedDict([('0',\n              Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: BertModel ),\n             ('1',\n              Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False}))])"
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the loaded modules, keyed by name\n",
    "modules"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## model structure\n",
    "`paraphrase-multilingual-MiniLM-L12-v2` consists of two parts.\n",
    "1. `Transformer`\n",
    "2. `Pooling`\n",
    "\n",
    "For more detailed information, you can explore the model directory.\n",
    "\n",
    "### 1. Transformer"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "outputs": [],
   "source": [
    "import os\n",
    "import json"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "../models/paraphrase-multilingual-MiniLM-L12-v2/0_Transformer\n"
     ]
    }
   ],
   "source": [
    "# Locate the first module's directory inside the sentence-transformers model folder;\n",
    "# the Hugging Face config, tokenizer and weights are loaded from it in the next cell.\n",
    "big_model_path = \"../models/paraphrase-multilingual-MiniLM-L12-v2\"\n",
    "\n",
    "modules_json_path = os.path.join(big_model_path, 'modules.json')\n",
    "with open(modules_json_path) as fIn:\n",
    "    modules_config = json.load(fIn)\n",
    "\n",
    "# Index rather than .get(): a missing 'path' should fail right here with a KeyError,\n",
    "# not inside os.path.join with a TypeError about None.\n",
    "model_path_01 = os.path.join(big_model_path, modules_config[0]['path'])\n",
    "print(model_path_01)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at ../models/paraphrase-multilingual-MiniLM-L12-v2/0_Transformer and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from transformers import (AutoConfig, AutoModel, AutoTokenizer)\n",
    "\n",
    "# Load the bare encoder without a task head. A task-specific class such as\n",
    "# AutoModelForQuestionAnswering would add a randomly initialised head that this\n",
    "# checkpoint has no weights for, yielding untrained logits instead of token embeddings.\n",
    "config_class, model_class, tokenizer_class = (AutoConfig, AutoModel, AutoTokenizer)\n",
    "\n",
    "config = config_class.from_pretrained(model_path_01)\n",
    "# NOTE(review): the Transformer module displayed earlier reports do_lower_case=False --\n",
    "# confirm that lower-casing is really intended here.\n",
    "tokenizer = tokenizer_class.from_pretrained(model_path_01, do_lower_case=True)\n",
    "model = model_class.from_pretrained(model_path_01, from_tf=False, config=config)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "outputs": [
    {
     "data": {
      "text/plain": "{'input_ids': tensor([[     0,  73014,   1322,      2,      1,      1,      1,      1,      1],\n        [     0,  33600,     31,      2,      1,      1,      1,      1,      1],\n        [     0,    186,  17723,      2,      1,      1,      1,      1,      1],\n        [     0,  73675,   4011, 134195,      4, 178084,   1322, 184120,      2]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0],\n        [0, 0, 0, 0, 0, 0, 0, 0, 0],\n        [0, 0, 0, 0, 0, 0, 0, 0, 0],\n        [0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 0, 0, 0, 0, 0],\n        [1, 1, 1, 1, 0, 0, 0, 0, 0],\n        [1, 1, 1, 1, 0, 0, 0, 0, 0],\n        [1, 1, 1, 1, 1, 1, 1, 1, 1]])}"
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sentences = ['您好', 'hello', 'be happy', '你吃了吗，我真的好饿']\n",
    "\n",
    "# Tokenize the whole batch at once: pad, truncate at 512 tokens, return PyTorch tensors.\n",
    "inputs = tokenizer(\n",
    "    sentences, padding=True, truncation=True, max_length=512, return_tensors=\"pt\"\n",
    ")\n",
    "inputs"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])\n"
     ]
    }
   ],
   "source": [
    "# List the fields produced by the tokenizer call\n",
    "print(inputs.keys())"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "outputs": [
    {
     "data": {
      "text/plain": "QuestionAnsweringModelOutput(loss=None, start_logits=tensor([[-0.0498, -0.1204, -0.1239, -0.0501, -0.0204, -0.0208, -0.0204, -0.0149,\n         -0.0158],\n        [-0.0786, -0.1878, -0.1830, -0.0789, -0.0092, -0.0098, -0.0173, -0.0067,\n         -0.0106],\n        [ 0.0054, -0.1231,  0.1735,  0.0049,  0.0053,  0.0055,  0.0026,  0.0032,\n         -0.0023],\n        [-0.0533, -0.1638, -0.0812, -0.1287, -0.1234, -0.0523, -0.0487, -0.0645,\n         -0.0534]], grad_fn=<CopyBackwards>), end_logits=tensor([[ 0.0284, -0.0114,  0.0034,  0.0282, -0.0351, -0.0409, -0.0394, -0.0323,\n         -0.0228],\n        [ 0.0227, -0.0387, -0.0143,  0.0226, -0.0176, -0.0267, -0.0333, -0.0151,\n         -0.0208],\n        [ 0.0335,  0.1232,  0.1914,  0.0335,  0.0006, -0.0024, -0.0033,  0.0069,\n          0.0040],\n        [ 0.1654,  0.1430,  0.2075,  0.1462,  0.1528,  0.2647,  0.1944,  0.2596,\n          0.1655]], grad_fn=<CopyBackwards>), hidden_states=None, attentions=None)"
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Inference only: put the model in evaluation mode and skip autograd bookkeeping.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    # Call the module itself rather than .forward() so registered hooks are run.\n",
    "    outputs = model(**inputs)\n",
    "outputs"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "odict_keys(['start_logits', 'end_logits'])\n"
     ]
    }
   ],
   "source": [
    "# List the fields present in the model output\n",
    "print(outputs.keys())"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### 2. Pooling"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}